from code_template import CodeTemplate

# Template for the whole generated C++ source file.
# ${copy_includes} receives the per-type '#include "ATen/<Name>.h"' lines and
# ${copy_functions} receives one rendered FUNCTION (s_copy definition) per
# destination type; both lists are built by create() below.
# The THC header is only pulled in when AT_CUDA_ENABLED is defined.
# NOTE(review): the '#undef THNN_' before including THC presumably avoids a
# macro clash between the TH and THC headers — confirm against those headers.
FILE = CodeTemplate("""\
#include "TH/TH.h"
#ifdef AT_CUDA_ENABLED
#undef THNN_
#include "THC/THC.h"
#endif
#include "ATen/Utils.h"
${copy_includes}

namespace at {

${copy_functions}

}
""")

# Template for one 'case' arm of the switch inside FUNCTION, selected by the
# source type's ID (${src_id}).
# ${THTensor} comes from the destination type's env (create_one substitutes
# with the dst env). ${cuda} is 'Cuda' when the source is a CUDA type.
# ${state,} holds 'context->thc_state' when either side is CUDA. Judging by
# the trailing comma in the placeholder, CodeTemplate presumably emits
# "<item>," for a non-empty list and nothing for an empty one — confirm
# against code_template.
# The source impl is static_cast to its concrete ${src_tensor} class so its
# underlying ->tensor can be passed to the generated copy function.
CASE = CodeTemplate("""\
case ${src_id}:
    ${THTensor}_copy${cuda}${src_scalar_name}(${state,}dst_->tensor,static_cast<${src_tensor}*>(src.pImpl)->tensor);
    break;
""")

# Template for the ${Type}::s_copy(src, dst) definition of one destination type.
# The generated code does the following:
#   - checks that dst is a ${Tensor} via checked_cast_tensor;
#   - dispatches on src.type().ID() to the CASE arms in ${copy_body};
#   - reports unsupported (src, dst) pairs through runtime_error in the
#     default arm;
#   - finally copies src's scalar flag onto dst.
# ${copy_body} may be empty; for sparse destinations create_one renders no
# cases, so only the default arm remains.
FUNCTION = CodeTemplate("""\
void ${Type}::s_copy(const Tensor & src, Tensor & dst) const {
  // code generated by function_wrapper
  auto dst_ = checked_cast_tensor<${Tensor}>(dst.pImpl,"dst",0,false);
  (void) dst_; //silence unused warning
  switch(src.type().ID()) {
    ${copy_body}
    default:
      runtime_error("copy does not support %s to %s copy.",src.type().toString(),toString());
      break;
  }
  dst.pImpl->setScalar(src.pImpl->isScalar());
}
""")


def create_one(env, all_types):
    """Render the ``s_copy`` definition for a single destination type.

    Args:
        env: type-environment dict of the destination type. It is read for
            'Density' and 'Backend' and handed to the templates as the base
            substitution environment.
        all_types: iterable of source type-environment dicts. Each must
            provide 'Density', 'Backend', 'ScalarName', 'TypeID' and
            'Tensor' (read below).

    Returns:
        ``FUNCTION`` rendered with one ``CASE`` arm per supported source type.
    """
    cases = []
    for src in all_types:
        # Copies to or from sparse types are not implemented yet.
        involves_sparse = env['Density'] == 'Sparse' or src['Density'] == 'Sparse'
        if involves_sparse:
            continue
        src_is_cuda = src['Backend'] == 'CUDA'
        # When either side is CUDA, 'context->thc_state' is passed as the
        # leading argument of the generated copy call.
        needs_state = env['Backend'] == 'CUDA' or src_is_cuda
        cases.append(CASE.substitute(
            env,
            src_scalar_name=src['ScalarName'],
            src_id=src['TypeID'],
            src_tensor=src['Tensor'],
            cuda='Cuda' if src_is_cuda else '',
            state=['context->thc_state'] if needs_state else [],
        ))
    return FUNCTION.substitute(env, copy_body=cases)


def create(all_types):
    """Render the generated C++ source holding every ``s_copy`` implementation.

    Args:
        all_types: sequence of type-environment dicts. It is iterated here and
            again inside ``create_one`` for every element, so it must be
            re-iterable (e.g. a list), not a one-shot iterator.

    Returns:
        The result of ``FILE.substitute``. For each destination type it
        contains ``#include`` lines for ``ATen/<Type>.h`` and
        ``ATen/<Tensor>.h``, plus the ``s_copy`` definition from ``create_one``.
    """
    includes = []
    functions = []
    for dst in all_types:
        for header_key in ('Type', 'Tensor'):
            includes.append('#include "ATen/{}.h"'.format(dst[header_key]))
        functions.append(create_one(dst, all_types))
    return FILE.substitute({
        'copy_includes': includes,
        'copy_functions': functions,
    })
